{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "style_transfer.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "g_nWetWWd_ns"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "2pHVBk_seED1",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6msVLevwcRhm"
      },
      "source": [
        "# 神经风格迁移"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ds4o1h4WHz9U"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/generative/style_transfer\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 tensorflow.google.cn 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/generative/style_transfer.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 上运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/generative/style_transfer.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 GitHub 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/generative/style_transfer.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载此 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GEe3i16tQPjo",
        "colab_type": "text"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://www.tensorflow.org/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aDyGj8DmXCJI"
      },
      "source": [
        "本教程使用深度学习来用其他图像的风格创造一个图像（曾经你是否希望可以像毕加索或梵高一样绘画？）。 这被称为*神经风格转移*，该技术概述于 <a href=\"https://arxiv.org/abs/1508.06576\" class=\"external\">A Neural Algorithm of Artistic Style</a> (Gatys et al.). \n",
        "\n",
        "神经风格迁移是一种优化技术，用于将两个图像——一个*内容*图像和一个*风格参考*图像（如著名画家的一个作品）——混合在一起，使输出的图像看起来像内容图像， 但是用了风格参考图像的风格。\n",
        "\n",
        "这是通过优化输出图像以匹配内容图像的内容统计数据和风格参考图像的风格统计数据来实现的。 这些统计数据可以使用卷积网络从图像中提取。\n",
        "\n",
        "例如，我们选取这张乌龟的照片和 Wassily Kandinsky 的作品 7：\n",
        "\n",
        "<img src=\"https://github.com/nounotabe/docs/blob/master/site/zh-cn/tutorials/generative/images/Green_Sea_Turtle_grazing_seagrass.jpg?raw=1\" style=\"width: 500px;\"/>\n",
        "\n",
        "[绿海龟照片](https://commons.wikimedia.org/wiki/File:Green_Sea_Turtle_grazing_seagrass.jpg) -摄影者： P.Lindgren [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0), 来自 Wikimedia Common\n",
        "\n",
        "<img src=\"https://github.com/nounotabe/docs/blob/master/site/zh-cn/tutorials/generative/images/kadinsky.jpg?raw=1\" style=\"width: 500px;\"/>\n",
        "\n",
        "\n",
        "如果 Kandinsky 决定用这种风格来专门描绘这只海龟会是什么样子？ 是否如下图一样？\n",
        "\n",
        "<img src=\"https://github.com/nounotabe/docs/blob/master/site/zh-cn/tutorials/generative/images/kadinsky-turtle.png?raw=1\" style=\"width: 500px;\"/>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "U8ajP_u73s6m"
      },
      "source": [
        "## 配置\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eqxUicSPUOP6"
      },
      "source": [
        "### 导入和配置模块"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2Mdpou0qzCm6",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "NyftRTSMuwue",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "sc1OLbOWhPCO",
        "colab": {}
      },
      "source": [
        "import IPython.display as display\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as mpl\n",
        "mpl.rcParams['figure.figsize'] = (12,12)\n",
        "mpl.rcParams['axes.grid'] = False\n",
        "\n",
        "import numpy as np\n",
        "import time\n",
        "import functools"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oeXebYusyHwC"
      },
      "source": [
        "下载图像并选择风格图像和内容图像："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wqc0OJHwyFAk",
        "colab": {}
      },
      "source": [
        "content_path = tf.keras.utils.get_file('turtle.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/Green_Sea_Turtle_grazing_seagrass.jpg')\n",
        "style_path = tf.keras.utils.get_file('kandinsky.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xE4Yt8nArTeR"
      },
      "source": [
        "## 将输入可视化"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "klh6ObK2t_vH"
      },
      "source": [
        "定义一个加载图像的函数，并将其最大尺寸限制为 512 像素。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3TLljcwv5qZs",
        "colab": {}
      },
      "source": [
        "def load_img(path_to_img):\n",
        "  max_dim = 512\n",
        "  img = tf.io.read_file(path_to_img)\n",
        "  img = tf.image.decode_image(img, channels=3)\n",
        "  img = tf.image.convert_image_dtype(img, tf.float32)\n",
        "\n",
        "  shape = tf.cast(tf.shape(img)[:-1], tf.float32)\n",
        "  long_dim = max(shape)\n",
        "  scale = max_dim / long_dim\n",
        "\n",
        "  new_shape = tf.cast(shape * scale, tf.int32)\n",
        "\n",
        "  img = tf.image.resize(img, new_shape)\n",
        "  img = img[tf.newaxis, :]\n",
        "  return img"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2yAlRzJZrWM3"
      },
      "source": [
        "创建一个简单的函数来显示图像："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cBX-eNT8PAK_",
        "colab": {}
      },
      "source": [
        "def imshow(image, title=None):\n",
        "  if len(image.shape) > 3:\n",
        "    image = tf.squeeze(image, axis=0)\n",
        "\n",
        "  plt.imshow(image)\n",
        "  if title:\n",
        "    plt.title(title)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_UWQmeEaiKkP",
        "colab": {}
      },
      "source": [
        "content_image = load_img(content_path)\n",
        "style_image = load_img(style_path)\n",
        "\n",
        "plt.subplot(1, 2, 1)\n",
        "imshow(content_image, 'Content Image')\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "imshow(style_image, 'Style Image')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GEwZ7FlwrjoZ"
      },
      "source": [
        "## 定义内容和风格的表示\n",
        "\n",
        "使用模型的中间层来获取图像的*内容*和*风格*表示。 从网络的输入层开始，前几个层的激励响应表示边缘和纹理等低级 feature (特征)。 随着层数加深，最后几层代表更高级的 feature (特征)——实体的部分，如*轮子*或*眼睛*。 在此教程中，我们使用的是 VGG19 网络结构，这是一个已经预训练好的图像分类网络。 这些中间层是从图像中定义内容和风格的表示所必需的。 对于一个输入图像，我们尝试匹配这些中间层的相应风格和内容目标的表示。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LP_7zrziuiJk"
      },
      "source": [
        "加载 [VGG19](https://keras.io/applications/#vgg19) 并在我们的图像上测试它以确保正常运行："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fMbzrr7BCTq0",
        "colab": {}
      },
      "source": [
        "x = tf.keras.applications.vgg19.preprocess_input(content_image*255)\n",
        "x = tf.image.resize(x, (224, 224))\n",
        "vgg = tf.keras.applications.VGG19(include_top=True, weights='imagenet')\n",
        "prediction_probabilities = vgg(x)\n",
        "prediction_probabilities.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "1_FyCm0dYnvl",
        "colab": {}
      },
      "source": [
        "predicted_top_5 = tf.keras.applications.vgg19.decode_predictions(prediction_probabilities.numpy())[0]\n",
        "[(class_name, prob) for (number, class_name, prob) in predicted_top_5]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ljpoYk-0f6HS"
      },
      "source": [
        "现在，加载没有分类部分的 `VGG19` ，并列出各层的名称："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Yh_AV6220ebD",
        "colab": {}
      },
      "source": [
        "vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')\n",
        "\n",
        "print()\n",
        "for layer in vgg.layers:\n",
        "  print(layer.name)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Wt-tASys0eJv"
      },
      "source": [
        "从网络中选择中间层的输出以表示图像的风格和内容：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ArfX_6iA0WAX",
        "colab": {}
      },
      "source": [
        "# 内容层将提取出我们的 feature maps （特征图）\n",
        "content_layers = ['block5_conv2'] \n",
        "\n",
        "# 我们感兴趣的风格层\n",
        "style_layers = ['block1_conv1',\n",
        "                'block2_conv1',\n",
        "                'block3_conv1', \n",
        "                'block4_conv1', \n",
        "                'block5_conv1']\n",
        "\n",
        "num_content_layers = len(content_layers)\n",
        "num_style_layers = len(style_layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2o4nSwuN0U3X"
      },
      "source": [
        "#### 用于表示风格和内容的中间层\n",
        "\n",
        "那么,为什么我们预训练的图像分类网络中的这些中间层的输出允许我们定义风格和内容的表示？\n",
        "\n",
        "从高层理解，为了使网络能够实现图像分类（该网络已被训练过），它必须理解图像。 这需要将原始图像作为输入像素并构建内部表示，这个内部表示将原始图像像素转换为对图像中存在的 feature (特征)的复杂理解。\n",
        "\n",
        "这也是卷积神经网络能够很好地推广的一个原因：它们能够捕获不变性并定义类别（例如猫与狗）之间的 feature (特征)，这些 feature (特征)与背景噪声和其他干扰无关。 因此，将原始图像传递到模型输入和分类标签输出之间的某处的这一过程，可以视作复杂的 feature (特征)提取器。通过这些模型的中间层，我们就可以描述输入图像的内容和风格。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Jt3i3RRrJiOX"
      },
      "source": [
        "## 建立模型 \n",
        "\n",
        "使用`tf.keras.applications`中的网络可以让我们非常方便的利用Keras的功能接口提取中间层的值。\n",
        "\n",
        "在使用功能接口定义模型时，我们需要指定输入和输出：\n",
        "\n",
        "`model = Model(inputs, outputs)`\n",
        "\n",
        "以下函数构建了一个 VGG19 模型，该模型返回一个中间层输出的列表："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "nfec6MuMAbPx",
        "colab": {}
      },
      "source": [
        "def vgg_layers(layer_names):\n",
        "  \"\"\" Creates a vgg model that returns a list of intermediate output values.\"\"\"\n",
        "  # 加载我们的模型。 加载已经在 imagenet 数据上预训练的 VGG \n",
        "  vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')\n",
        "  vgg.trainable = False\n",
        "  \n",
        "  outputs = [vgg.get_layer(name).output for name in layer_names]\n",
        "\n",
        "  model = tf.keras.Model([vgg.input], outputs)\n",
        "  return model"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jbaIvZf5wWn_"
      },
      "source": [
        "然后建立模型："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LkyvPpBHSfVi",
        "colab": {}
      },
      "source": [
        "style_extractor = vgg_layers(style_layers)\n",
        "style_outputs = style_extractor(style_image*255)\n",
        "\n",
        "#查看每层输出的统计信息\n",
        "for name, output in zip(style_layers, style_outputs):\n",
        "  print(name)\n",
        "  print(\"  shape: \", output.numpy().shape)\n",
        "  print(\"  min: \", output.numpy().min())\n",
        "  print(\"  max: \", output.numpy().max())\n",
        "  print(\"  mean: \", output.numpy().mean())\n",
        "  print()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lGUfttK9F8d5"
      },
      "source": [
        "## 风格计算\n",
        "\n",
        "图像的内容由中间 feature maps (特征图)的值表示。\n",
        "\n",
        "事实证明，图像的风格可以通过不同 feature maps (特征图)上的平均值和相关性来描述。 通过在每个位置计算 feature (特征)向量的外积，并在所有位置对该外积进行平均,可以计算出包含此信息的 Gram 矩阵。 对于特定层的 Gram 矩阵，具体计算方法如下所示：\n",
        "\n",
        "$$G^l_{cd} = \\frac{\\sum_{ij} F^l_{ijc}(x)F^l_{ijd}(x)}{IJ}$$\n",
        "\n",
        "这可以使用`tf.linalg.einsum`函数来实现："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HAy1iGPdoEpZ",
        "colab": {}
      },
      "source": [
        "def gram_matrix(input_tensor):\n",
        "  result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)\n",
        "  input_shape = tf.shape(input_tensor)\n",
        "  num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)\n",
        "  return result/(num_locations)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pXIUX6czZABh"
      },
      "source": [
        "## 提取风格和内容\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1HGHvwlJ1nkn"
      },
      "source": [
        "构建一个返回风格和内容张量的模型。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Sr6QALY-I1ja",
        "colab": {}
      },
      "source": [
        "class StyleContentModel(tf.keras.models.Model):\n",
        "  def __init__(self, style_layers, content_layers):\n",
        "    super(StyleContentModel, self).__init__()\n",
        "    self.vgg =  vgg_layers(style_layers + content_layers)\n",
        "    self.style_layers = style_layers\n",
        "    self.content_layers = content_layers\n",
        "    self.num_style_layers = len(style_layers)\n",
        "    self.vgg.trainable = False\n",
        "\n",
        "  def call(self, inputs):\n",
        "    \"Expects float input in [0,1]\"\n",
        "    inputs = inputs*255.0\n",
        "    preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)\n",
        "    outputs = self.vgg(preprocessed_input)\n",
        "    style_outputs, content_outputs = (outputs[:self.num_style_layers], \n",
        "                                      outputs[self.num_style_layers:])\n",
        "\n",
        "    style_outputs = [gram_matrix(style_output)\n",
        "                     for style_output in style_outputs]\n",
        "\n",
        "    content_dict = {content_name:value \n",
        "                    for content_name, value \n",
        "                    in zip(self.content_layers, content_outputs)}\n",
        "\n",
        "    style_dict = {style_name:value\n",
        "                  for style_name, value\n",
        "                  in zip(self.style_layers, style_outputs)}\n",
        "    \n",
        "    return {'content':content_dict, 'style':style_dict}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Xuj1o33t1edl"
      },
      "source": [
        "在图像上调用此模型，可以返回 style_layers 的 gram 矩阵（风格）和 content_layers 的内容："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rkjO-DoNDU0A",
        "colab": {}
      },
      "source": [
        "extractor = StyleContentModel(style_layers, content_layers)\n",
        "\n",
        "results = extractor(tf.constant(content_image))\n",
        "\n",
        "style_results = results['style']\n",
        "\n",
        "print('Styles:')\n",
        "for name, output in sorted(results['style'].items()):\n",
        "  print(\"  \", name)\n",
        "  print(\"    shape: \", output.numpy().shape)\n",
        "  print(\"    min: \", output.numpy().min())\n",
        "  print(\"    max: \", output.numpy().max())\n",
        "  print(\"    mean: \", output.numpy().mean())\n",
        "  print()\n",
        "\n",
        "print(\"Contents:\")\n",
        "for name, output in sorted(results['content'].items()):\n",
        "  print(\"  \", name)\n",
        "  print(\"    shape: \", output.numpy().shape)\n",
        "  print(\"    min: \", output.numpy().min())\n",
        "  print(\"    max: \", output.numpy().max())\n",
        "  print(\"    mean: \", output.numpy().mean())\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y9r8Lyjb_m0u"
      },
      "source": [
        "## 梯度下降\n",
        "\n",
        "使用此风格和内容提取器，我们现在可以实现风格传输算法。我们通过计算每个图像的输出和目标的均方误差来做到这一点，然后取这些损失值的加权和。\n",
        "\n",
        "设置风格和内容的目标值："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "PgkNOnGUFcKa",
        "colab": {}
      },
      "source": [
        "style_targets = extractor(style_image)['style']\n",
        "content_targets = extractor(content_image)['content']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CNPrpl-e_w9A"
      },
      "source": [
        "定义一个 `tf.Variable` 来表示要优化的图像。 为了快速实现这一点，使用内容图像对其进行初始化（ `tf.Variable` 必须与内容图像的形状相同）"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "J0vKxF8ZO6G8",
        "colab": {}
      },
      "source": [
        "image = tf.Variable(content_image)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "M6L8ojmn_6rH"
      },
      "source": [
        "由于这是一个浮点图像，因此我们定义一个函数来保持像素值在 0 和 1 之间："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kdgpTJwL_vE2",
        "colab": {}
      },
      "source": [
        "def clip_0_1(image):\n",
        "  return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MBU5RFpcAo7W"
      },
      "source": [
        "创建一个 optimizer 。 本教程推荐 LBFGS，但 `Adam` 也可以正常工作："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "r4XZjqUk_5Eu",
        "colab": {}
      },
      "source": [
        "opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "As-evbBiA2qT"
      },
      "source": [
        "为了优化它，我们使用两个损失的加权组合来获得总损失："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Dt4pxarvA4I4",
        "colab": {}
      },
      "source": [
        "style_weight=1e-2\n",
        "content_weight=1e4"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "0ggx2Na8oROH",
        "colab": {}
      },
      "source": [
        "def style_content_loss(outputs):\n",
        "    style_outputs = outputs['style']\n",
        "    content_outputs = outputs['content']\n",
        "    style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2) \n",
        "                           for name in style_outputs.keys()])\n",
        "    style_loss *= style_weight / num_style_layers\n",
        "\n",
        "    content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2) \n",
        "                             for name in content_outputs.keys()])\n",
        "    content_loss *= content_weight / num_content_layers\n",
        "    loss = style_loss + content_loss\n",
        "    return loss"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vbF2WnP9BI5M"
      },
      "source": [
        "使用 `tf.GradientTape` 来更新图像。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "0t0umkajFIuh",
        "colab": {}
      },
      "source": [
        "@tf.function()\n",
        "def train_step(image):\n",
        "  with tf.GradientTape() as tape:\n",
        "    outputs = extractor(image)\n",
        "    loss = style_content_loss(outputs)\n",
        "\n",
        "  grad = tape.gradient(loss, image)\n",
        "  opt.apply_gradients([(grad, image)])\n",
        "  image.assign(clip_0_1(image))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5FHMJq4UBRIQ"
      },
      "source": [
        "现在，我们运行几个步来测试一下："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Y542mxi-O2a2",
        "colab": {}
      },
      "source": [
        "train_step(image)\n",
        "train_step(image)\n",
        "train_step(image)\n",
        "plt.imshow(image.read_value()[0])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mNzE-mTbBVgY"
      },
      "source": [
        "运行正常，我们来执行一个更长的优化："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rQW1tXYoLbUS",
        "colab": {}
      },
      "source": [
        "import time\n",
        "start = time.time()\n",
        "\n",
        "epochs = 10\n",
        "steps_per_epoch = 100\n",
        "\n",
        "step = 0\n",
        "for n in range(epochs):\n",
        "  for m in range(steps_per_epoch):\n",
        "    step += 1\n",
        "    train_step(image)\n",
        "    print(\".\", end='')\n",
        "  display.clear_output(wait=True)\n",
        "  imshow(image.read_value())\n",
        "  plt.title(\"Train step: {}\".format(step))\n",
        "  plt.show()\n",
        "\n",
        "end = time.time()\n",
        "print(\"Total time: {:.1f}\".format(end-start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GWVB3anJMY2v"
      },
      "source": [
        "## 总变分损失\n",
        "\n",
        "此实现只是一个基础版本，它的一个缺点是它会产生大量的高频误差。 我们可以直接通过正则化图像的高频分量来减少这些高频误差。 在风格转移中，这通常被称为*总变分损失*："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7szUUybCQMB3",
        "colab": {}
      },
      "source": [
        "def high_pass_x_y(image):\n",
        "  x_var = image[:,:,1:,:] - image[:,:,:-1,:]\n",
        "  y_var = image[:,1:,:,:] - image[:,:-1,:,:]\n",
        "\n",
        "  return x_var, y_var"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Atc2oL29PXu_",
        "colab": {}
      },
      "source": [
        "x_deltas, y_deltas = high_pass_x_y(content_image)\n",
        "\n",
        "plt.figure(figsize=(14,10))\n",
        "plt.subplot(2,2,1)\n",
        "imshow(clip_0_1(2*y_deltas+0.5), \"Horizontal Deltas: Original\")\n",
        "\n",
        "plt.subplot(2,2,2)\n",
        "imshow(clip_0_1(2*x_deltas+0.5), \"Vertical Deltas: Original\")\n",
        "\n",
        "x_deltas, y_deltas = high_pass_x_y(image)\n",
        "\n",
        "plt.subplot(2,2,3)\n",
        "imshow(clip_0_1(2*y_deltas+0.5), \"Horizontal Deltas: Styled\")\n",
        "\n",
        "plt.subplot(2,2,4)\n",
        "imshow(clip_0_1(2*x_deltas+0.5), \"Vertical Deltas: Styled\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lqHElVgBkgkz"
      },
      "source": [
        "这显示了高频分量如何增加。\n",
        "\n",
        "而且，本质上高频分量是一个边缘检测器。 我们可以从 Sobel 边缘检测器获得类似的输出，例如："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HyvqCiywiUfL",
        "colab": {}
      },
      "source": [
        "plt.figure(figsize=(14,10))\n",
        "\n",
        "sobel = tf.image.sobel_edges(content_image)\n",
        "plt.subplot(1,2,1)\n",
        "imshow(clip_0_1(sobel[...,0]/4+0.5), \"Horizontal Sobel-edges\")\n",
        "plt.subplot(1,2,2)\n",
        "imshow(clip_0_1(sobel[...,1]/4+0.5), \"Vertical Sobel-edges\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vv5bKlSDnPP7"
      },
      "source": [
        "与此相关的正则化损失是这些值的平方和："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mP-92lXMIYPn",
        "colab": {}
      },
      "source": [
        "def total_variation_loss(image):\n",
        "  x_deltas, y_deltas = high_pass_x_y(image)\n",
        "  return tf.reduce_mean(x_deltas**2) + tf.reduce_mean(y_deltas**2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nTessd-DCdcC"
      },
      "source": [
        "## 重新进行优化\n",
        "\n",
        "选择 `total_variation_loss` 的权重："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "tGeRLD4GoAd4",
        "colab": {}
      },
      "source": [
        "total_variation_weight=1e8"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kG1-T4kJsoAv"
      },
      "source": [
        "现在，将它加入 `train_step` 函数中："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BzmfcyyYUyWq",
        "colab": {}
      },
      "source": [
        "@tf.function()\n",
        "def train_step(image):\n",
        "  with tf.GradientTape() as tape:\n",
        "    outputs = extractor(image)\n",
        "    loss = style_content_loss(outputs)\n",
        "    loss += total_variation_weight*total_variation_loss(image)\n",
        "\n",
        "  grad = tape.gradient(loss, image)\n",
        "  opt.apply_gradients([(grad, image)])\n",
        "  image.assign(clip_0_1(image))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lcLWBQChsutQ"
      },
      "source": [
        "重新初始化优化的变量："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "a-dPRr8BqexB",
        "colab": {}
      },
      "source": [
        "image = tf.Variable(content_image)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BEflRstmtGBu"
      },
      "source": [
        "并进行优化："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "q3Cc3bLtoOWy",
        "colab": {}
      },
      "source": [
        "import time\n",
        "start = time.time()\n",
        "\n",
        "epochs = 10\n",
        "steps_per_epoch = 100\n",
        "\n",
        "step = 0\n",
        "for n in range(epochs):\n",
        "  for m in range(steps_per_epoch):\n",
        "    step += 1\n",
        "    train_step(image)\n",
        "    print(\".\", end='')\n",
        "  display.clear_output(wait=True)\n",
        "  imshow(image.read_value())\n",
        "  plt.title(\"Train step: {}\".format(step))\n",
        "  plt.show()\n",
        "\n",
        "end = time.time()\n",
        "print(\"Total time: {:.1f}\".format(end-start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KKox7K46tKxy"
      },
      "source": [
        "最后，保存结果："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SSH6OpyyQn7w",
        "colab": {}
      },
      "source": [
        "file_name = 'kadinsky-turtle.png'\n",
        "mpl.image.imsave(file_name, image[0])\n",
        "\n",
        "try:\n",
        "  from google.colab import files\n",
        "except ImportError:\n",
        "   pass\n",
        "else:\n",
        "  files.download(file_name)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}